package cc.shacocloud.mirage.bean.impl;

import cc.shacocloud.mirage.bean.ClassScanner;
import cc.shacocloud.mirage.utils.ClassUtil;
import cc.shacocloud.mirage.utils.Utils;
import cc.shacocloud.mirage.utils.charSequence.StrUtil;
import cc.shacocloud.mirage.utils.collection.ArrayUtil;
import cc.shacocloud.mirage.utils.map.ConcurrentReferenceHashMap;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.URL;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarInputStream;

/**
 * 通过扫描类路径来查找类。在文件系统和 jar 文件中搜索类
 *
 * @author 思追(shaco)
 */
@Slf4j
public class ClassPathScanner implements ClassScanner {
    
    private final ClassLoader classLoader;
    
    // 扫描路径缓存
    private final Map<String, Set<Class<?>>> scanPathCache = new ConcurrentReferenceHashMap<>();
    
    /**
     * 使用线程上下文类装入器或此类的类装入器创建新的类路径扫描程序
     */
    public ClassPathScanner() {
        this.classLoader = ClassUtil.getDefaultClassLoader();
    }
    
    /**
     * 使用指定的类装入器创建新的类路径扫描程序。
     *
     * @param classLoader 用于扫描类并装入类的类加载器
     */
    public ClassPathScanner(@NotNull ClassLoader classLoader) {
        this.classLoader = classLoader;
    }
    
    @Override
    public ClassLoader getClassLoader() {
        return classLoader;
    }
    
    /**
     * {@inheritDoc}
     */
    @Override
    public Set<Class<?>> findClasses(String... basePackages) {
        if (ArrayUtil.isEmpty(basePackages)) {
            throw new IllegalArgumentException("basePackages 不可以为空");
        }
        
        try {
            long start = log.isDebugEnabled() ? System.currentTimeMillis() : 0;
            
            Set<Class<?>> classes = doFindClasses(basePackages);
            
            if (log.isDebugEnabled()) {
                long stop = System.currentTimeMillis();
                log.debug(String.format("扫描类路径 %s 共花费 %s ms，扫描到 %s 个类", Utils.nullSafeToString(basePackages),
                        stop - start, classes.size()));
            }
            
            return classes;
        } catch (Exception e) {
            throw new RuntimeException(String.format("扫描类路径 %s 发生例外！", Utils.nullSafeToString(basePackages)), e);
        }
        
    }
    
    /**
     * 扫描指定的基础路径
     */
    private @NotNull Set<Class<?>> doFindClasses(String @NotNull ... basePackages) throws Exception {
        Set<Class<?>> classes = new HashSet<>();
        
        for (String basePackage : basePackages) {
            
            // 过滤无效路径
            if (StrUtil.isNotBlank(basePackage)) {
                classes.addAll(findClassesFromBasePackage(basePackage));
            }
        }
        
        return classes;
    }
    
    /**
     * 扫描基础路径上所有的类
     */
    protected @NotNull Set<Class<?>> findClassesFromBasePackage(@NotNull String basePackage) throws Exception {
        return scanPathCache.computeIfAbsent(basePackage, k -> {
            
            final Set<Class<?>> classes = new HashSet<>();
            final String path = k.replace('.', '/');
            
            final Enumeration<URL> resources;
            try {
                resources = classLoader.getResources(path);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            
            if (resources != null) {
                AccessController.doPrivileged((PrivilegedAction<Object>) () -> {
                    while (resources.hasMoreElements()) {
                        URL resource = resources.nextElement();
                        try {
                            classes.addAll(getClassesFromResource(k, path, resource));
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    }
                    return null;
                });
            }
            
            return classes;
        });
    }
    
    /**
     * 从指定 {@code resource}  中获取指定类
     *
     * @param basePackage 基础包名
     * @param path        {@code basePackage} 的文件路径格式
     * @param resource    资源对象
     */
    @NotNull
    protected Set<Class<?>> getClassesFromResource(@NotNull String basePackage,
                                                   @NotNull String path,
                                                   @NotNull URL resource) throws IOException, ClassNotFoundException {
        final Set<Class<?>> classes = new HashSet<>();
        final String filePath = getFilePath(resource);
        
        if (filePath != null) {
            // 如果是 jar 文件路径
            if (isJarFilePath(filePath)) {
                final String jarPath = getJarPath(filePath);
                classes.addAll(getFromJARFile(jarPath, path));
            } else {
                classes.addAll(getFromDirectory(new File(filePath), basePackage));
            }
        }
        
        return classes;
    }
    
    /**
     * 从指定的文件目录获取
     *
     * @param directory   文件对象
     * @param packageName 基础包名
     */
    @NotNull
    private Set<Class<?>> getFromDirectory(@NotNull File directory, @NotNull String packageName) throws ClassNotFoundException {
        Set<Class<?>> classes = new HashSet<>();
        
        if (directory.exists()) {
            File[] files = directory.listFiles();
            
            if (ArrayUtil.isNotEmpty(files)) {
                for (File file : files) {
                    if (file.isDirectory()) {
                        classes.addAll(getFromDirectory(file, packageName + "." + file.getName()));
                    } else if (isClass(file.getName())) {
                        final String className = packageName + '.' + stripFileExtension(file.getName());
                        final Class<?> clazz = loadClass(className);
                        addClass(clazz, classes);
                    }
                }
            }
        }
        
        return classes;
    }
    
    /**
     * 从jar文件中获取指定包路径下的所有类
     */
    @NotNull
    private Set<Class<?>> getFromJARFile(@NotNull String jarPath, @NotNull String packageName) throws IOException, ClassNotFoundException {
        Set<Class<?>> classes = new HashSet<>();
        
        try (JarInputStream jarFile = new JarInputStream(new FileInputStream(jarPath))) {
            JarEntry jarEntry;
            
            do {
                jarEntry = jarFile.getNextJarEntry();
                if (jarEntry != null) {
                    String fileName = jarEntry.getName();
                    
                    if (isClass(fileName)) {
                        // 打包插件将主类jar的类文件放在该位置，故这边将它移除
                        fileName = StrUtil.removePrefix(fileName, "BOOT-INF/classes/");
                        
                        String className = stripFileExtension(fileName);
                        
                        if (className.startsWith(packageName)) {
                            Class<?> clazz = loadClass(className.replace('/', '.'));
                            addClass(clazz, classes);
                            
                        }
                    }
                }
            } while (jarEntry != null);
        }
        
        return classes;
    }
    
    /**
     * 获取 {@link URL} 对应的文件路径
     */
    @Nullable
    private String getFilePath(@NotNull URL url) {
        String filePath = url.getFile();
        
        if (filePath != null) {
            return fixWindowsSpace(filePath);
        }
        
        return null;
    }
    
    /**
     * 判断是否为 jar 包中的文件路径
     */
    private boolean isJarFilePath(final @NotNull String filePath) {
        return (filePath.indexOf("!") > 0) && (filePath.indexOf(".jar") > 0);
    }
    
    /**
     * 修复 windows 下的文件目录空格问题
     */
    @NotNull
    private String fixWindowsSpace(@NotNull String filePath) {
        if (filePath.indexOf("%20") > 0) {
            return filePath.replaceAll("%20", " ");
        }
        
        return filePath;
    }
    
    /**
     * 提取jar路径
     */
    private @NotNull String getJarPath(@NotNull String filePath) {
        final String jarPath = filePath.substring(0, filePath.indexOf("!")).substring(filePath.indexOf(":") + 1);
        return fixWindowsJarPath(jarPath);
    }
    
    /**
     * 修复 windows 下的 jar 路径的 冒号问题
     */
    private @NotNull String fixWindowsJarPath(@NotNull String jarPath) {
        if (jarPath.contains(":")) {
            return jarPath.substring(1);
        }
        
        return jarPath;
    }
    
    /**
     * 跳过文件扩展名
     */
    private String stripFileExtension(final String filename) {
        if (filename == null) {
            return null;
        }
        
        final int dotIndex = filename.lastIndexOf(".");
        
        if (dotIndex == -1) {
            return filename;
        }
        
        return filename.substring(0, dotIndex);
    }
    
    /**
     * 判断该文件是否为 class 文件
     */
    @Contract(pure = true)
    private boolean isClass(@NotNull String fileName) {
        return fileName.endsWith(".class");
    }
    
    /**
     * 添加 class 文件，如果类不是普通类则跳过
     */
    private void addClass(@NotNull Class<?> clazz, @NotNull Set<Class<?>> classes) {
        // 过滤枚举，注解，抽象类，匿名类，编译器生成的类
        if (!clazz.isAnonymousClass() && !clazz.isAnnotation() && !clazz.isEnum() && !clazz.isSynthetic()) {
            classes.add(clazz);
        }
    }
    
    /**
     * 加载 class
     */
    private Class<?> loadClass(String className) throws ClassNotFoundException {
        return ClassUtil.forName(className, classLoader);
    }
}
